Add weight tying support for Llama3 #2580
Merged
tianyu-l merged 3 commits into pytorch:main · Mar 24, 2026
Conversation
tianyu-l (Contributor) requested changes on Mar 15, 2026 and left a comment:
IIUC the existing model registry doesn't have the Llama 3.2 1B / 3B models, which are the only variants that have weight tying enabled. Please add those models to llama3/__init__.py. You can refer to the exact config in the earlier attempt #1376.
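A minimal sketch of what such registry entries could look like. The dataclass and field names are stand-ins for illustration (not torchtitan's actual `TransformerModelArgs`), and the hyperparameter values are assumptions based on the public Llama 3.2 model cards rather than the config in #1376:

```python
from dataclasses import dataclass

# Hypothetical stand-in for the repo's model-args dataclass; field names
# are assumptions for illustration only.
@dataclass
class ModelArgs:
    dim: int
    n_layers: int
    n_heads: int
    n_kv_heads: int
    vocab_size: int = 128256
    rope_theta: float = 500000.0
    enable_weight_tying: bool = False

# Llama 3.2 1B / 3B -- the only Llama variants that tie
# tok_embeddings.weight to output.weight.
llama3_configs = {
    "1B": ModelArgs(dim=2048, n_layers=16, n_heads=32, n_kv_heads=8,
                    enable_weight_tying=True),
    "3B": ModelArgs(dim=3072, n_layers=28, n_heads=24, n_kv_heads=8,
                    enable_weight_tying=True),
}
```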
Ties tok_embeddings.weight to output.weight via an enable_weight_tying config flag. Follows the same pattern as Qwen3 (pytorch#1590). Closes pytorch#1524.
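The tying pattern itself is small; here is a minimal sketch on a toy module (the class is illustrative, only the attribute names mirror the PR description):

```python
import torch.nn as nn

class TinyTransformer(nn.Module):
    """Toy model showing the weight-tying pattern described above."""

    def __init__(self, vocab_size: int, dim: int, enable_weight_tying: bool = False):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)   # (vocab, dim)
        self.output = nn.Linear(dim, vocab_size, bias=False)  # weight: (vocab, dim)
        if enable_weight_tying:
            # Rebind to a single shared Parameter: gradients from both
            # uses accumulate into one tensor, and checkpoints carry
            # only one copy of the matrix.
            self.output.weight = self.tok_embeddings.weight
```

Because the shapes already agree ((vocab_size, dim) on both sides), the assignment is a plain Parameter rebind with no transpose needed.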
Llama 3.2 1B and 3B are the only Llama variants with weight tying, so they belong in the registry; without them the feature has no real entry point. I also dropped the try/except guard in test_weight_tying.py, which was inconsistent with the other unit tests here and silently skipped on broken imports.
Force-pushed from 8d7f787 to ceda986.
Contributor
please fix tests
pytorch-bot bot pushed a commit that referenced this pull request on Mar 27, 2026:
Implements enable_weight_tying for Llama3, sharing tok_embeddings.weight with output.weight. It mirrors the Qwen3 implementation from #1590 (thanks!). Changes cover model.py (config field, tying in __init__/init_weights, PP guard), parallelize.py (grouped FSDP unit for tied params), state_dict_adapter.py (skip/reconstruct output.weight for HF conversion), and a new unit test file. Closes #1524.
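The state-dict skip/reconstruct step can be sketched as plain dict manipulation: when weights are tied, the HF checkpoint stores only the embedding, so output.weight is dropped on export and re-pointed on import. Function names and key names below are illustrative assumptions, not torchtitan's actual adapter API:

```python
# Hypothetical sketch of the skip/reconstruct idea for HF conversion.
TT_TO_HF = {"tok_embeddings.weight": "model.embed_tokens.weight"}
HF_TO_TT = {v: k for k, v in TT_TO_HF.items()}

def to_hf(state_dict: dict, enable_weight_tying: bool) -> dict:
    """Rename keys to HF convention, skipping the tied output.weight
    so the shared matrix is not stored twice."""
    hf_sd = {}
    for key, value in state_dict.items():
        if enable_weight_tying and key == "output.weight":
            continue  # tied to tok_embeddings.weight; HF keeps one copy
        hf_sd[TT_TO_HF.get(key, key)] = value
    return hf_sd

def from_hf(hf_sd: dict, enable_weight_tying: bool) -> dict:
    """Rename keys back, reconstructing output.weight from the embedding
    when tying is enabled and the HF checkpoint omits it."""
    sd = {HF_TO_TT.get(key, key): value for key, value in hf_sd.items()}
    if enable_weight_tying and "output.weight" not in sd:
        sd["output.weight"] = sd["tok_embeddings.weight"]
    return sd
```

A round trip through both directions should leave a tied model loadable with the shared tensor reconstructed.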